# Copyright Huawei Technologies Co., Ltd. 2025. All rights reserved.
import argparse
import os

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch

from modelslim.pytorch.llm_ptq.anti_outlier import AntiOutlierConfig, AntiOutlier
from msmodelslim.pytorch.llm_ptq.llm_ptq_tools import Calibrator, QuantConfig


CPU = "cpu"
NPU = "npu"

# NOTE(review): AntiOutlier is imported from `modelslim` while Calibrator comes
# from `msmodelslim` — confirm both package names are correct for the installed
# toolkit version; they are usually the same package.
if __name__ == '__main__':
    # Post-training quantization driver: load model, suppress activation
    # outliers, quantize with calibration images, save safetensors weights.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='')
    parser.add_argument('--calib_images', type=str, default='./calibImages')
    parser.add_argument('--save_directory', type=str, default='')
    parser.add_argument('--part_file_size', type=int, default=None)
    parser.add_argument('--w_bit', type=int, default=8)
    parser.add_argument('--a_bit', type=int, default=8)
    parser.add_argument('--device_type', type=str, choices=[CPU, NPU], default=CPU)
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    tokenizer.padding_side = 'left'
    # Qwen-style tokenizers expose `eod_id`; reuse it as the pad token.
    tokenizer.pad_token_id = tokenizer.eod_id

    # 1. Load the model. On CPU the weights must be float32 (half-precision
    # CPU kernels are generally unavailable); on NPU keep the checkpoint dtype.
    device_map = CPU if args.device_type == CPU else "auto"
    config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
    dtype = config.torch_dtype if args.device_type == NPU else torch.float32
    # Fix: `dtype` was previously computed but torch.float16 was passed
    # unconditionally, which broke the intended CPU (float32) path.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, device_map=device_map, trust_remote_code=True, torch_dtype=dtype).eval()

    # 2. Rollback layers kept in floating point: every transformer block's MLP
    # output projection, the visual attention pooler, and the LM head.
    # Layer count now follows the loaded config instead of a hard-coded 32.
    num_layers = getattr(config, "num_hidden_layers", 32)
    disable_names = [f"transformer.h.{layer}.mlp.c_proj" for layer in range(num_layers)]
    disable_names.append("transformer.visual.attn_pool.attn.out_proj")
    disable_names.append("lm_head")

    # 3. Build the calibration dataset from every file in the image folder.
    images_list = os.listdir(args.calib_images)
    prompt = "USER: <image>\nDescribe this image in detail. ASSISTANT:"
    calib_data = []
    for image_name in images_list:
        image_path = os.path.join(args.calib_images, image_name)
        query = tokenizer.from_list_format([
            {'image': image_path},  # either a local path or a URL
            {'text': prompt},
        ])
        inputs = tokenizer(query, return_tensors='pt')
        inputs = inputs.to(model.device)
        # Calibrator expects [input_ids, position_ids (None), attention_mask].
        calib_data.append([inputs['input_ids'], None, inputs['attention_mask']])

    # 4. Outlier suppression (anti-outlier method "m2") before quantization.
    anti_config = AntiOutlierConfig(
        w_bit=args.w_bit,
        a_bit=args.a_bit,
        anti_method="m2",
        dev_type=args.device_type,
        dev_id=model.device.index,
    )
    anti_outlier = AntiOutlier(model, calib_data=calib_data, cfg=anti_config)
    anti_outlier.process()

    # 5. Post-training quantization driven by the calibration data.
    quant_config = QuantConfig(
        w_bit=args.w_bit,
        a_bit=args.a_bit,
        disable_names=disable_names,
        dev_type=args.device_type,
        dev_id=model.device.index,
        act_method=2,
        mm_tensor=False,
    )
    calibrator = Calibrator(model, quant_config, calib_data=calib_data, disable_level='L0')
    calibrator.run()

    # 6. Save the quantized weights as safetensors, optionally sharded by size.
    calibrator.save(args.save_directory, save_type=["safe_tensor"], part_file_size=args.part_file_size)